{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7765UFHoyGx6"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "KsOkK8O69PyT"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZS8z-_KeywY9"
      },
      "source": [
        "# TF Lattice Canned Estimators"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r61fkA2i9Y3_"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/lattice/tutorials/canned_estimators\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/lattice/blob/master/docs/tutorials/canned_estimators.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/lattice/blob/master/docs/tutorials/canned_estimators.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/lattice/docs/tutorials/canned_estimators.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WCpl-9WDVq9d"
      },
      "source": [
        "## Overview\n",
        "\n",
        "Canned estimators provide a quick and easy way to train TFL models for typical use cases. This guide outlines the steps needed to create a TFL canned estimator."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x769lI12IZXB"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fbBVAR6UeRN5"
      },
      "source": [
        "Installing the TF Lattice package:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bpXjJKpSd3j4"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "!pip install tensorflow-lattice"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jSVl9SHTeSGX"
      },
      "source": [
        "Importing required packages:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "FbZDk8bIx8ig"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "import copy\n",
        "import logging\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import sys\n",
        "import tensorflow_lattice as tfl\n",
        "from tensorflow import feature_column as fc\n",
        "logging.disable(sys.maxsize)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "svPuM6QNxlrH"
      },
      "source": [
        "Downloading the UCI Statlog (Heart) dataset:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "j-k1qTR_yvBl"
      },
      "outputs": [],
      "source": [
        "csv_file = tf.keras.utils.get_file(\n",
        "    'heart.csv', 'http://storage.googleapis.com/download.tensorflow.org/data/heart.csv')\n",
        "df = pd.read_csv(csv_file)\n",
        "target = df.pop('target')\n",
        "train_size = int(len(df) * 0.8)\n",
        "train_x = df[:train_size]\n",
        "train_y = target[:train_size]\n",
        "test_x = df[train_size:]\n",
        "test_y = target[train_size:]\n",
        "df.head()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nKkAw12SxvGG"
      },
      "source": [
        "Setting the default values used for training in this guide:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "1T6GFI9F6mcG"
      },
      "outputs": [],
      "source": [
        "LEARNING_RATE = 0.01\n",
        "BATCH_SIZE = 128\n",
        "NUM_EPOCHS = 500\n",
        "PREFITTING_NUM_EPOCHS = 10"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0TGfzhPHzpix"
      },
      "source": [
        "## Feature Columns\n",
        "\n",
        "As with any other TF estimator, data is passed to the estimator, typically via an `input_fn`, and parsed using [FeatureColumns](https://www.tensorflow.org/guide/feature_columns)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DCIUz8apzs0l"
      },
      "outputs": [],
      "source": [
        "# Feature columns.\n",
        "# - age\n",
        "# - sex\n",
        "# - cp        chest pain type (4 values)\n",
        "# - trestbps  resting blood pressure\n",
        "# - chol      serum cholesterol in mg/dl\n",
        "# - fbs       fasting blood sugar > 120 mg/dl\n",
        "# - restecg   resting electrocardiographic results (values 0,1,2)\n",
        "# - thalach   maximum heart rate achieved\n",
        "# - exang     exercise induced angina\n",
        "# - oldpeak   ST depression induced by exercise relative to rest\n",
        "# - slope     the slope of the peak exercise ST segment\n",
        "# - ca        number of major vessels (0-3) colored by fluoroscopy\n",
        "# - thal      3 = normal; 6 = fixed defect; 7 = reversible defect\n",
        "feature_columns = [\n",
        "    fc.numeric_column('age', default_value=-1),\n",
        "    fc.categorical_column_with_vocabulary_list('sex', [0, 1]),\n",
        "    fc.numeric_column('cp'),\n",
        "    fc.numeric_column('trestbps', default_value=-1),\n",
        "    fc.numeric_column('chol'),\n",
        "    fc.categorical_column_with_vocabulary_list('fbs', [0, 1]),\n",
        "    fc.categorical_column_with_vocabulary_list('restecg', [0, 1, 2]),\n",
        "    fc.numeric_column('thalach'),\n",
        "    fc.categorical_column_with_vocabulary_list('exang', [0, 1]),\n",
        "    fc.numeric_column('oldpeak'),\n",
        "    fc.categorical_column_with_vocabulary_list('slope', [0, 1, 2]),\n",
        "    fc.numeric_column('ca'),\n",
        "    fc.categorical_column_with_vocabulary_list(\n",
        "        'thal', ['normal', 'fixed', 'reversible']),\n",
        "]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hEZstmtT2CA3"
      },
      "source": [
        "TFL canned estimators use the type of the feature column to decide what type of calibration layer to use. We use a `tfl.layers.PWLCalibration` layer for numeric feature columns and a `tfl.layers.CategoricalCalibration` layer for categorical feature columns.\n",
        "\n",
        "Note that categorical feature columns are not wrapped in an embedding feature column; they are fed directly into the estimator."
      ]
    },
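    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the two calibration types concrete, here is a simplified NumPy sketch (not the TFL implementation). A piecewise-linear calibrator interpolates linearly between output values defined at fixed input keypoints and clamps inputs outside the keypoint range, while a categorical calibrator is a lookup of one learned output per category. The keypoints, output values, and helper names below are hypothetical, hand-picked illustrations rather than trained parameters."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical input keypoints (e.g. for 'age') and the calibrated output at\n",
        "# each keypoint. Nondecreasing outputs make the calibrator monotonically increasing.\n",
        "input_keypoints = np.array([29.0, 45.0, 55.0, 62.0, 77.0])\n",
        "output_values = np.array([0.0, 0.3, 0.6, 0.8, 1.0])\n",
        "\n",
        "def pwl_calibrate(x):\n",
        "  # np.interp interpolates linearly between keypoints and clamps to the\n",
        "  # first/last output value outside the keypoint range.\n",
        "  return np.interp(x, input_keypoints, output_values)\n",
        "\n",
        "# Hypothetical categorical calibrator: one output value per category.\n",
        "restecg_outputs = {0: 0.1, 1: 0.4, 2: 0.5}\n",
        "\n",
        "def categorical_calibrate(values):\n",
        "  return np.array([restecg_outputs[v] for v in values])\n",
        "\n",
        "print(pwl_calibrate(np.array([20.0, 50.0, 90.0])))\n",
        "print(categorical_calibrate([0, 2, 1]))"
      ]
    },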
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "H_LoW_9m5OFL"
      },
      "source": [
        "## Creating input_fn\n",
        "\n",
        "As with any other estimator, you can use an `input_fn` to feed data to the model for training and evaluation. TFL estimators can automatically calculate quantiles of the features and use them as input keypoints for the PWL calibration layer. To do so, you need to pass a `feature_analysis_input_fn`, which is similar to the training `input_fn` but runs for a single epoch or over a subsample of the data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lFVy1Efy5NKD"
      },
      "outputs": [],
      "source": [
        "train_input_fn = tf.compat.v1.estimator.inputs.pandas_input_fn(\n",
        "    x=train_x,\n",
        "    y=train_y,\n",
        "    shuffle=False,\n",
        "    batch_size=BATCH_SIZE,\n",
        "    num_epochs=NUM_EPOCHS,\n",
        "    num_threads=1)\n",
        "\n",
        "# feature_analysis_input_fn is used to collect statistics about the input.\n",
        "feature_analysis_input_fn = tf.compat.v1.estimator.inputs.pandas_input_fn(\n",
        "    x=train_x,\n",
        "    y=train_y,\n",
        "    shuffle=False,\n",
        "    batch_size=BATCH_SIZE,\n",
        "    # Note that we only need one pass over the data.\n",
        "    num_epochs=1,\n",
        "    num_threads=1)\n",
        "\n",
        "test_input_fn = tf.compat.v1.estimator.inputs.pandas_input_fn(\n",
        "    x=test_x,\n",
        "    y=test_y,\n",
        "    shuffle=False,\n",
        "    batch_size=BATCH_SIZE,\n",
        "    num_epochs=1,\n",
        "    num_threads=1)\n",
        "\n",
        "# Serving input fn is used to create saved models.\n",
        "serving_input_fn = (\n",
        "    tf.estimator.export.build_parsing_serving_input_receiver_fn(\n",
        "        feature_spec=fc.make_parse_example_spec(feature_columns)))"
      ]
    },
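    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The quantile-based keypoint initialization described above can be approximated in a few lines of NumPy: with `num_keypoints` keypoints, take evenly spaced quantiles from the minimum to the maximum of the feature. This sketch assumes that simple scheme; the estimator's own feature analysis may treat ties, duplicate keypoints, and missing values differently. The helper name `quantile_keypoints` and the sample values are illustrative only."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def quantile_keypoints(values, num_keypoints):\n",
        "  # Evenly spaced quantile levels from 0 (minimum) to 1 (maximum).\n",
        "  levels = np.linspace(0.0, 1.0, num_keypoints)\n",
        "  # np.unique sorts and drops duplicate keypoints, which can occur for\n",
        "  # features with few distinct values.\n",
        "  return np.unique(np.quantile(values, levels))\n",
        "\n",
        "# Illustrative values; in this notebook you could pass train_x['age'] instead.\n",
        "ages = np.array([34, 41, 45, 52, 54, 57, 58, 60, 63, 70])\n",
        "print(quantile_keypoints(ages, num_keypoints=5))"
      ]
    },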
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uQlzREcm2Wbj"
      },
      "source": [
        "## Feature Configs\n",
        "\n",
        "Feature calibration and per-feature configurations are set using `tfl.configs.FeatureConfig`. Feature configurations include monotonicity constraints, per-feature regularization (see `tfl.configs.RegularizerConfig`), and lattice sizes for lattice models.\n",
        "\n",
        "If no configuration is defined for an input feature, the default configuration in `tfl.configs.FeatureConfig` is used."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vD0tNpiO3p9c"
      },
      "outputs": [],
      "source": [
        "# Feature configs are used to specify how each feature is calibrated and used.\n",
        "feature_configs = [\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='age',\n",
        "        lattice_size=3,\n",
        "        # By default, the PWL input keypoints are quantiles of the feature.\n",
        "        pwl_calibration_num_keypoints=5,\n",
        "        monotonicity='increasing',\n",
        "        pwl_calibration_clip_max=100,\n",
        "        # Per feature regularization.\n",
        "        regularizer_configs=[\n",
        "            tfl.configs.RegularizerConfig(name='calib_wrinkle', l2=0.1),\n",
        "        ],\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='cp',\n",
        "        pwl_calibration_num_keypoints=4,\n",
        "        # Keypoints can be uniformly spaced.\n",
        "        pwl_calibration_input_keypoints='uniform',\n",
        "        monotonicity='increasing',\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='chol',\n",
        "        # Explicit input keypoint initialization.\n",
        "        pwl_calibration_input_keypoints=[126.0, 210.0, 247.0, 286.0, 564.0],\n",
        "        monotonicity='increasing',\n",
        "        # Calibration can be forced to span the full output range by clamping.\n",
        "        pwl_calibration_clamp_min=True,\n",
        "        pwl_calibration_clamp_max=True,\n",
        "        # Per feature regularization.\n",
        "        regularizer_configs=[\n",
        "            tfl.configs.RegularizerConfig(name='calib_hessian', l2=1e-4),\n",
        "        ],\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='fbs',\n",
        "        # Partial monotonicity: output(0) <= output(1)\n",
        "        monotonicity=[(0, 1)],\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='trestbps',\n",
        "        pwl_calibration_num_keypoints=5,\n",
        "        monotonicity='decreasing',\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='thalach',\n",
        "        pwl_calibration_num_keypoints=5,\n",
        "        monotonicity='decreasing',\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='restecg',\n",
        "        # Partial monotonicity: output(0) <= output(1), output(0) <= output(2)\n",
        "        monotonicity=[(0, 1), (0, 2)],\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='exang',\n",
        "        # Partial monotonicity: output(0) <= output(1)\n",
        "        monotonicity=[(0, 1)],\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='oldpeak',\n",
        "        pwl_calibration_num_keypoints=5,\n",
        "        monotonicity='increasing',\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='slope',\n",
        "        # Partial monotonicity: output(0) <= output(1), output(1) <= output(2)\n",
        "        monotonicity=[(0, 1), (1, 2)],\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='ca',\n",
        "        pwl_calibration_num_keypoints=4,\n",
        "        monotonicity='increasing',\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='thal',\n",
        "        # Partial monotonicity:\n",
        "        # output(normal) <= output(fixed)\n",
        "        # output(normal) <= output(reversible)\n",
        "        monotonicity=[('normal', 'fixed'), ('normal', 'reversible')],\n",
        "    ),\n",
        "]"
      ]
    },
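    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Each pair `(a, b)` in a categorical `monotonicity` list requests that the calibrated output for category `a` be no greater than the output for category `b`. Pairs only constrain the categories they list: with the `thal` config above, `fixed` and `reversible` remain unordered relative to each other. The helper below is a hypothetical utility (not part of TFL) that checks made-up calibrated outputs against such pairs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def satisfies_pairs(outputs, pairs):\n",
        "  # outputs maps each category to its calibrated value; every (a, b) pair\n",
        "  # requires outputs[a] <= outputs[b].\n",
        "  return all(outputs[a] <= outputs[b] for a, b in pairs)\n",
        "\n",
        "thal_pairs = [('normal', 'fixed'), ('normal', 'reversible')]\n",
        "print(satisfies_pairs({'normal': 0.1, 'fixed': 0.5, 'reversible': 0.7}, thal_pairs))  # True\n",
        "print(satisfies_pairs({'normal': 0.6, 'fixed': 0.5, 'reversible': 0.7}, thal_pairs))  # False"
      ]
    },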
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LKBULveZ4mr3"
      },
      "source": [
        "## Calibrated Linear Model\n",
        "\n",
        "To construct a TFL canned estimator, construct a model configuration from `tfl.configs`. A calibrated linear model is constructed using `tfl.configs.CalibratedLinearConfig`. It applies piecewise-linear and categorical calibration to the input features, followed by a linear combination and an optional output piecewise-linear calibration. When output calibration is used or output bounds are specified, the linear layer applies weighted averaging to the calibrated inputs.\n",
        "\n",
        "This example creates a calibrated linear model on the first 5 features. We use\n",
        "`tfl.visualization` to plot the model graph with the calibrator plots."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "diRRozio4sAL"
      },
      "outputs": [],
      "source": [
        "# Model config defines the model structure for the estimator.\n",
        "model_config = tfl.configs.CalibratedLinearConfig(\n",
        "    feature_configs=feature_configs,\n",
        "    use_bias=True,\n",
        "    output_calibration=True,\n",
        "    regularizer_configs=[\n",
        "        # Regularizer for the output calibrator.\n",
        "        tfl.configs.RegularizerConfig(name='output_calib_hessian', l2=1e-4),\n",
        "    ])\n",
        "# A CannedClassifier is constructed from the given model config.\n",
        "estimator = tfl.estimators.CannedClassifier(\n",
        "    feature_columns=feature_columns[:5],\n",
        "    model_config=model_config,\n",
        "    feature_analysis_input_fn=feature_analysis_input_fn,\n",
        "    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),\n",
        "    config=tf.estimator.RunConfig(tf_random_seed=42))\n",
        "estimator.train(input_fn=train_input_fn)\n",
        "results = estimator.evaluate(input_fn=test_input_fn)\n",
        "print('Calibrated linear test AUC: {}'.format(results['auc']))\n",
        "saved_model_path = estimator.export_saved_model(estimator.model_dir,\n",
        "                                                serving_input_fn)\n",
        "model_graph = tfl.estimators.get_model_graph(saved_model_path)\n",
        "tfl.visualization.draw_model_graph(model_graph)"
      ]
    },
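    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `calib_hessian`, `calib_wrinkle`, and `output_calib_hessian` regularizers used in this guide smooth the calibrators. Roughly, a Hessian penalty discourages changes in slope (pushing a calibrator toward a straight line), and a wrinkle penalty discourages abrupt changes in curvature. The sketch below computes simplified L2 versions with NumPy finite differences, assuming uniformly spaced keypoints; TFL's implementation may differ, for example in how it accounts for keypoint spacing. The helper functions and output values are hypothetical."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def hessian_penalty(outputs, l2):\n",
        "  # Second differences are zero exactly when the calibrator is linear.\n",
        "  return l2 * np.sum(np.diff(outputs, n=2) ** 2)\n",
        "\n",
        "def wrinkle_penalty(outputs, l2):\n",
        "  # Third differences measure how abruptly the curvature changes.\n",
        "  return l2 * np.sum(np.diff(outputs, n=3) ** 2)\n",
        "\n",
        "# Hypothetical calibrator outputs at five uniformly spaced keypoints.\n",
        "linear = np.array([0.0, 0.25, 0.5, 0.75, 1.0])\n",
        "kinked = np.array([0.0, 0.1, 0.2, 0.9, 1.0])\n",
        "print(hessian_penalty(linear, l2=1e-4), hessian_penalty(kinked, l2=1e-4))\n",
        "print(wrinkle_penalty(linear, l2=0.1), wrinkle_penalty(kinked, l2=0.1))"
      ]
    },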
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zWzPM2_p977t"
      },
      "source": [
        "## Calibrated Lattice Model\n",
        "\n",
        "A calibrated lattice model is constructed using `tfl.configs.CalibratedLatticeConfig`. A calibrated lattice model applies piecewise-linear and categorical calibration on the input features, followed by a lattice model and an optional output piecewise-linear calibration.\n",
        "\n",
        "This example creates a calibrated lattice model on the first 5 features.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "C6EvVpKW4BbC"
      },
      "outputs": [],
      "source": [
        "# This is a calibrated lattice model: Inputs are calibrated, then combined\n",
        "# non-linearly using a lattice layer.\n",
        "model_config = tfl.configs.CalibratedLatticeConfig(\n",
        "    feature_configs=feature_configs,\n",
        "    regularizer_configs=[\n",
        "        # Torsion regularizer applied to the lattice to make it more linear.\n",
        "        tfl.configs.RegularizerConfig(name='torsion', l2=1e-4),\n",
        "        # Globally defined calibration regularizer is applied to all features.\n",
        "        tfl.configs.RegularizerConfig(name='calib_hessian', l2=1e-4),\n",
        "    ])\n",
        "# A CannedClassifier is constructed from the given model config.\n",
        "estimator = tfl.estimators.CannedClassifier(\n",
        "    feature_columns=feature_columns[:5],\n",
        "    model_config=model_config,\n",
        "    feature_analysis_input_fn=feature_analysis_input_fn,\n",
        "    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),\n",
        "    config=tf.estimator.RunConfig(tf_random_seed=42))\n",
        "estimator.train(input_fn=train_input_fn)\n",
        "results = estimator.evaluate(input_fn=test_input_fn)\n",
        "print('Calibrated lattice test AUC: {}'.format(results['auc']))\n",
        "saved_model_path = estimator.export_saved_model(estimator.model_dir,\n",
        "                                                serving_input_fn)\n",
        "model_graph = tfl.estimators.get_model_graph(saved_model_path)\n",
        "tfl.visualization.draw_model_graph(model_graph)"
      ]
    },
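    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A lattice is a grid of learned values over the calibrated inputs, and its output interpolates the grid values surrounding the input. The sketch below shows multilinear (here bilinear) interpolation for two inputs on a 2x2 lattice, assuming the inputs are already calibrated into [0, 1]. It also computes a simplified torsion term: the squared coefficient of the `x * y` interaction, which is zero when the lattice is linear in its inputs. The grid values are made up, and `tfl.layers.Lattice` handles larger grids and more inputs than this illustration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical 2x2 lattice: weights[i, j] is the output at corner (x=i, y=j).\n",
        "weights = np.array([[0.0, 0.4],\n",
        "                    [0.3, 1.0]])\n",
        "\n",
        "def bilinear_lattice(x, y):\n",
        "  # Each corner is weighted by the product of per-dimension interpolation\n",
        "  # weights; x and y are calibrated inputs in [0, 1].\n",
        "  return (weights[0, 0] * (1 - x) * (1 - y) + weights[1, 0] * x * (1 - y)\n",
        "          + weights[0, 1] * (1 - x) * y + weights[1, 1] * x * y)\n",
        "\n",
        "def torsion(w):\n",
        "  # Coefficient of the x*y term of the bilinear function, squared: the\n",
        "  # penalty vanishes when the two inputs do not interact.\n",
        "  return (w[0, 0] + w[1, 1] - w[0, 1] - w[1, 0]) ** 2\n",
        "\n",
        "print(bilinear_lattice(0.5, 0.5))\n",
        "print(torsion(weights))"
      ]
    },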
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9494K_ZBKFcm"
      },
      "source": [
        "## Calibrated Lattice Ensemble\n",
        "\n",
        "When the number of features is large, you can use an ensemble model, which creates multiple smaller lattices over subsets of the features and averages their outputs instead of creating a single huge lattice. Ensemble lattice models are constructed using `tfl.configs.CalibratedLatticeEnsembleConfig`. A calibrated lattice ensemble model applies piecewise-linear and categorical calibration to the input features, followed by an ensemble of lattice models and an optional output piecewise-linear calibration.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KjrzziMFKuCB"
      },
      "source": [
        "### Random Lattice Ensemble\n",
        "\n",
        "The following model config uses a random subset of features for each lattice."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YBSS7dLjKExq"
      },
      "outputs": [],
      "source": [
        "# This is a random lattice ensemble model with separate calibration:\n",
        "# model output is the average output of separately calibrated lattices.\n",
        "model_config = tfl.configs.CalibratedLatticeEnsembleConfig(\n",
        "    feature_configs=feature_configs,\n",
        "    num_lattices=5,\n",
        "    lattice_rank=3)\n",
        "# A CannedClassifier is constructed from the given model config.\n",
        "estimator = tfl.estimators.CannedClassifier(\n",
        "    feature_columns=feature_columns,\n",
        "    model_config=model_config,\n",
        "    feature_analysis_input_fn=feature_analysis_input_fn,\n",
        "    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),\n",
        "    config=tf.estimator.RunConfig(tf_random_seed=42))\n",
        "estimator.train(input_fn=train_input_fn)\n",
        "results = estimator.evaluate(input_fn=test_input_fn)\n",
        "print('Random ensemble test AUC: {}'.format(results['auc']))\n",
        "saved_model_path = estimator.export_saved_model(estimator.model_dir,\n",
        "                                                serving_input_fn)\n",
        "model_graph = tfl.estimators.get_model_graph(saved_model_path)\n",
        "tfl.visualization.draw_model_graph(model_graph, calibrator_dpi=15)"
      ]
    },
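    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Conceptually, a random lattice ensemble assigns each of the `num_lattices` lattices a random subset of `lattice_rank` features and averages the lattice outputs. The sketch below illustrates only that assignment-and-averaging idea with a seeded NumPy generator; it is not TFL's exact assignment procedure, and the per-lattice outputs are placeholders."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "feature_names = ['age', 'sex', 'cp', 'trestbps', 'chol', 'fbs', 'restecg',\n",
        "                 'thalach', 'exang', 'oldpeak', 'slope', 'ca', 'thal']\n",
        "num_lattices, lattice_rank = 5, 3\n",
        "rng = np.random.default_rng(42)\n",
        "\n",
        "# Each lattice gets lattice_rank distinct features, sampled without replacement.\n",
        "assignments = [\n",
        "    [str(f) for f in rng.choice(feature_names, size=lattice_rank, replace=False)]\n",
        "    for _ in range(num_lattices)]\n",
        "for i, features in enumerate(assignments):\n",
        "  print('lattice', i, features)\n",
        "\n",
        "# Placeholder per-lattice outputs for one example; the ensemble averages them.\n",
        "lattice_outputs = np.array([0.2, 0.5, 0.4, 0.7, 0.6])\n",
        "print('ensemble output:', lattice_outputs.mean())"
      ]
    },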
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7uyO8s97FGJM"
      },
      "source": [
        "### RTL Layer Random Lattice Ensemble\n",
        "\n",
        "The following model config uses a `tfl.layers.RTL` layer that uses a random subset of features for each lattice. Note that `tfl.layers.RTL` only supports monotonicity constraints, requires the same lattice size for all features, and does not support per-feature regularization. Using a `tfl.layers.RTL` layer lets you scale to much larger ensembles than using separate `tfl.layers.Lattice` instances."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8v7dKg-FF7iz"
      },
      "outputs": [],
      "source": [
        "# Make sure our feature configs have the same lattice size, no per-feature\n",
        "# regularization, and only monotonicity constraints.\n",
        "rtl_layer_feature_configs = copy.deepcopy(feature_configs)\n",
        "for feature_config in rtl_layer_feature_configs:\n",
        "  feature_config.lattice_size = 2\n",
        "  feature_config.unimodality = 'none'\n",
        "  feature_config.reflects_trust_in = None\n",
        "  feature_config.dominates = None\n",
        "  feature_config.regularizer_configs = None\n",
        "# This is an RTL layer ensemble model with separate calibration:\n",
        "# model output is the average output of separately calibrated lattices.\n",
        "model_config = tfl.configs.CalibratedLatticeEnsembleConfig(\n",
        "    lattices='rtl_layer',\n",
        "    feature_configs=rtl_layer_feature_configs,\n",
        "    num_lattices=5,\n",
        "    lattice_rank=3)\n",
        "# A CannedClassifier is constructed from the given model config.\n",
        "estimator = tfl.estimators.CannedClassifier(\n",
        "    feature_columns=feature_columns,\n",
        "    model_config=model_config,\n",
        "    feature_analysis_input_fn=feature_analysis_input_fn,\n",
        "    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),\n",
        "    config=tf.estimator.RunConfig(tf_random_seed=42))\n",
        "estimator.train(input_fn=train_input_fn)\n",
        "results = estimator.evaluate(input_fn=test_input_fn)\n",
        "print('RTL ensemble test AUC: {}'.format(results['auc']))\n",
        "saved_model_path = estimator.export_saved_model(estimator.model_dir,\n",
        "                                                serving_input_fn)\n",
        "model_graph = tfl.estimators.get_model_graph(saved_model_path)\n",
        "tfl.visualization.draw_model_graph(model_graph, calibrator_dpi=15)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LSXEaYAULRvf"
      },
      "source": [
        "### Crystals Lattice Ensemble\n",
        "\n",
        "TFL also provides a heuristic feature arrangement algorithm, called [Crystals](https://papers.nips.cc/paper/6377-fast-and-flexible-monotonic-functions-with-ensembles-of-lattices). The Crystals algorithm first trains a *prefitting model* that estimates pairwise feature interactions. It then arranges the final ensemble such that features with more non-linear interactions are in the same lattices.\n",
        "\n",
        "For Crystals models, you also need to provide a `prefitting_input_fn` that is used to train the prefitting model. It is similar to the training `input_fn`, but because the prefitting model does not need to be fully trained, a few epochs (here `PREFITTING_NUM_EPOCHS`) should be enough.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FjQKh9saMaFu"
      },
      "outputs": [],
      "source": [
        "prefitting_input_fn = tf.compat.v1.estimator.inputs.pandas_input_fn(\n",
        "    x=train_x,\n",
        "    y=train_y,\n",
        "    shuffle=False,\n",
        "    batch_size=BATCH_SIZE,\n",
        "    num_epochs=PREFITTING_NUM_EPOCHS,\n",
        "    num_threads=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fVnZpwX8MtPi"
      },
      "source": [
        "You can then create a Crystals model by setting `lattices='crystals'` in the model config."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f4awRMDe-eMv"
      },
      "outputs": [],
      "source": [
        "# This is a Crystals ensemble model with separate calibration: model output is\n",
        "# the average output of separately calibrated lattices.\n",
        "model_config = tfl.configs.CalibratedLatticeEnsembleConfig(\n",
        "    feature_configs=feature_configs,\n",
        "    lattices='crystals',\n",
        "    num_lattices=5,\n",
        "    lattice_rank=3)\n",
        "# A CannedClassifier is constructed from the given model config.\n",
        "estimator = tfl.estimators.CannedClassifier(\n",
        "    feature_columns=feature_columns,\n",
        "    model_config=model_config,\n",
        "    feature_analysis_input_fn=feature_analysis_input_fn,\n",
        "    # prefitting_input_fn is required to train the prefitting model.\n",
        "    prefitting_input_fn=prefitting_input_fn,\n",
        "    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),\n",
        "    prefitting_optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),\n",
        "    config=tf.estimator.RunConfig(tf_random_seed=42))\n",
        "estimator.train(input_fn=train_input_fn)\n",
        "results = estimator.evaluate(input_fn=test_input_fn)\n",
        "print('Crystals ensemble test AUC: {}'.format(results['auc']))\n",
        "saved_model_path = estimator.export_saved_model(estimator.model_dir,\n",
        "                                                serving_input_fn)\n",
        "model_graph = tfl.estimators.get_model_graph(saved_model_path)\n",
        "tfl.visualization.draw_model_graph(model_graph, calibrator_dpi=15)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Isb2vyLAVBM1"
      },
      "source": [
        "You can plot feature calibrators with more details using the `tfl.visualization` module."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DJPaREuWS2sg"
      },
      "outputs": [],
      "source": [
        "_ = tfl.visualization.plot_feature_calibrator(model_graph, \"age\")\n",
        "_ = tfl.visualization.plot_feature_calibrator(model_graph, \"restecg\")"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "canned_estimators.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
